import os
import subprocess as sp
import time

import torch

import utils


class Trainer:
    """Epoch-based supervised training loop driven by a shared ``full_package`` dict.

    ``full_package`` carries the training objects (loaders, model, criterion,
    optimizer, scheduler, device, config, summary writers) together with the
    mutable progress state (``epoch``, ``it_global``, ``iterations_per_epoch``,
    ``running_vars``). The trainer reads from and writes back into that dict,
    so the caller observes the updated progress after training.
    """

    def __init__(self, full_package, config):
        """Bind the trainer to ``full_package`` and reset its progress counters.

        Args:
            full_package: dict holding the training objects; must contain every
                key listed in ``initialize_attributes``.
            config: nested config dict. NOTE(review): ``initialize_attributes``
                overwrites ``self.config`` with ``full_package["config"]``, so
                the two are presumably meant to be the same config — confirm.
        """
        self.loop_type = "default"
        self.full_package = full_package
        self.config = config
        self.initialize_attributes()

    def initialize_attributes(self):
        """Reset progress state in ``full_package`` and mirror its objects onto ``self``.

        Sets ``epoch`` and ``it_global`` to 0, replaces ``running_vars`` with
        fresh best-accuracy trackers, copies the listed keys to instance
        attributes, and records ``iterations_per_epoch = len(self.trainloader)``.
        """
        self.full_package["it_global"] = 0
        self.full_package["epoch"] = 0
        self.full_package["running_vars"] = {
            "best_test_acc": 0.0,
            "best_test_epoch": 0,
            "best_eval_acc": 0.0,
            "best_eval_epoch": 0,
        }
        # NOTE: "epoch" is copied by value, so ``self.epoch`` stays 0; the live
        # counter is ``self.full_package["epoch"]``.
        self.set_attributes(
            [
                "trainloader",
                "valloader",
                "testloader",
                "model",
                "criterion",
                "optimizer",
                "device",
                "config",
                "epoch",
            ],
            self.full_package,
        )
        self.full_package["iterations_per_epoch"] = len(self.trainloader)

    def run(self):
        """Train epoch by epoch until a stopping criterion is met, then clean up.

        Each pass trains one epoch, checkpoints and evaluates, and steps the
        LR scheduler. The stopping check runs after the pass, so at least one
        epoch is always trained.
        """
        print("==> Start training..")
        while True:
            self.train_loop(self.trainloader)
            self.save_and_evaluate()
            self.full_package["scheduler"].step()
            if self.is_debug_mode():
                break
            if not self.should_continue_training():
                break
        self.finishing_up()

    def get_gpu_memory(self):
        """Return the free memory of each visible GPU as reported by ``nvidia-smi``.

        Values are the leading integer of each CSV row (MiB with nvidia-smi's
        default csv units — confirm for your driver).

        Raises:
            FileNotFoundError: if ``nvidia-smi`` is not installed.
            subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
        """
        command = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv"]
        output = sp.check_output(command).decode("ascii")
        # The first line is the CSV header; each data row looks like "<value> <unit>".
        rows = output.splitlines()[1:]
        return [int(row.split()[0]) for row in rows if row.strip()]

    def train_loop(self, trainloader):
        """Run one training epoch over ``trainloader`` and advance global progress.

        Raises:
            AssertionError: if ``config["train"]["loop_type"]`` is not "default".
        """
        assert (
            self.config["train"]["loop_type"] == self.loop_type
        ), "Only default loop type is supported for now."
        self.init_epoch_variables()
        self.model.train()

        print("==> Start training epoch {}".format(self.full_package["epoch"]))
        for it, data in enumerate(trainloader):
            loss, correct = self.train_step(data)
            self.update_running_vars(it, loss, correct, data)

            if self.should_log_loss():
                self.log_loss_train()
            if self.is_debug_mode():
                break

        self.update_global_progress()

    def train_step(self, data):
        """Perform one optimization step on a single batch.

        Returns:
            (loss, correct): the loss tensor and the number of samples whose
            argmax prediction equals the label.
        """
        inputs, labels = self.prepare_data(data)
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()

        correct = (outputs.argmax(1) == labels).sum().item()

        return loss, correct

    # Evaluation
    def eval_loop(self, loader, continue_training=False):
        """Compute top-1 accuracy (in percent) of the model over ``loader``.

        Args:
            loader: iterable of ``(inputs, labels, ...)`` batches.
            continue_training: if True, switch the model back to train mode
                afterwards; otherwise it is left in eval mode.

        Returns:
            Accuracy in percent, or 0.0 when ``loader`` yields no samples.
        """
        self.model.eval()
        correct, total = 0.0, 0.0
        for i, data in enumerate(loader):
            inputs, labels = self.prepare_data(data)
            with torch.no_grad():
                outputs = self.model(inputs)
            outputs = outputs.detach().cpu()
            labels = labels.detach().cpu()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
            # In test-debug mode only a handful of batches are evaluated.
            if self.config["test"]["debug"] and i >= 5:
                break

        if continue_training:
            self.model.train()

        # Guard against an empty loader instead of raising ZeroDivisionError.
        return correct / total * 100 if total else 0.0

    def finishing_up(self):
        """Write the final running statistics and close every summary writer."""
        self.save_final_stats()

        for writer in self.full_package["summary_writers"].values():
            writer.close()
        print("Finished Training")

    ################## Helper functions ##################
    def set_attributes(self, keys, full_package):
        """Copy each ``full_package[key]`` onto ``self`` as an attribute.

        Raises:
            AssertionError: if a key is missing from ``full_package``.
        """
        for key in keys:
            assert key in full_package, f"Key {key} not in full_package."
            # Read from the dict that was validated, not implicitly from self.
            setattr(self, key, full_package[key])

    # High level
    def should_continue_training(self):
        """Return True while either the epoch or the global-iteration budget remains.

        NOTE(review): training stops only once BOTH ``epochs`` and
        ``global_iteration`` limits are reached; confirm ``or`` (rather than
        ``and``) is the intended policy.
        """
        epoch_condition = self.full_package["epoch"] < self.config["train"]["epochs"]
        iteration_condition = (
            self.full_package["it_global"] < self.config["train"]["global_iteration"]
        )
        return epoch_condition or iteration_condition

    def is_debug_mode(self):
        """Return True when train-debug mode is on and its short budget is used up.

        The budget is 100 iterations within an epoch or 5 completed epochs.
        """
        if not self.config["train"]["debug"]:
            return False
        return (
            self.full_package["running_vars"]["it"] >= 100
            or self.full_package["epoch"] >= 5
        )

    # Epoch level
    def update_running_vars(self, iteration, loss, correct, data):
        """Record per-iteration statistics in ``running_vars``.

        Stores the iteration index, the detached current loss, the loss summed
        since the last log, and the batch's correct count, size and accuracy (%).
        """
        running_vars = self.full_package["running_vars"]
        total_loss = running_vars.get("total_loss", 0.0) + loss.item()
        batch_size = data[1].size(0)
        running_vars.update(
            {
                "it": iteration,
                # Detach so running_vars (which is checkpointed and dumped to
                # the stats file) does not hold onto the autograd graph.
                "cur_loss": loss.detach(),
                "total_loss": total_loss,
                "correct": correct,
                "total": batch_size,
                "acc": correct / batch_size * 100,
            }
        )

    def update_global_progress(self):
        """Advance ``it_global`` by one epoch's worth of iterations and bump ``epoch``.

        A full epoch is counted even if debug mode stopped the epoch early.
        """
        self.full_package["it_global"] += self.full_package["iterations_per_epoch"]
        self.full_package["epoch"] += 1

    # Iteration level
    def prepare_data(self, data):
        """Move a batch to the device and cast inputs to float and labels to long.

        Singleton dimensions after the batch dimension are removed from the
        labels (e.g. ``(B, 1) -> (B,)``), while the batch dimension is always
        kept so a batch of one stays shape ``(1,)``.
        """
        inputs = data[0].to(self.device).float()
        labels = data[1].to(self.device).long()
        if labels.dim() > 1:
            kept_dims = [d for d in labels.shape[1:] if d != 1]
            labels = labels.reshape(labels.shape[0], *kept_dims)
        return inputs, labels

    def init_epoch_variables(self):
        """Start the epoch timer and reset the loss accumulator for a new epoch."""
        running_vars = self.full_package["running_vars"]
        running_vars["t0"] = time.time()
        # Without this reset, iterations after the last log of the previous
        # epoch would leak into the first logged mean loss of this epoch.
        running_vars["total_loss"] = 0.0

    # Log information
    # Log: training loss
    def should_log_loss(self):
        """Return True on the last iteration of every ``print_every`` window."""
        it = self.full_package["running_vars"]["it"]
        print_every = self.config["train"]["print_every"]
        return print_every > 0 and it % print_every == print_every - 1

    def log_loss_train(self):
        """Log training loss/accuracy to TensorBoard and the results file.

        Writes the latest batch loss and accuracy as scalars. It then prints and
        appends a summary line with the mean loss over the last
        ``print_every`` iterations, time since the epoch started, free GPU
        memory and the learning rate. Finally it resets the loss accumulator.
        """
        running_vars = self.full_package["running_vars"]
        writers = self.full_package["summary_writers"]
        it_global = self.get_global_iteration()
        loss = running_vars["cur_loss"].item()
        total_loss = running_vars["total_loss"]
        time_elapsed = time.time() - running_vars["t0"]
        train_acc = running_vars["acc"]

        writers["train"].add_scalar("loss/train", loss, global_step=it_global)
        writers["train"].add_scalar("acc/train", train_acc, global_step=it_global)

        # Free GPU memory is informational only: do not abort training when
        # nvidia-smi is unavailable (e.g. CPU-only runs) or its output is odd.
        try:
            free_gpu_mem = self.get_gpu_memory()[0]
        except (OSError, sp.CalledProcessError, IndexError, ValueError):
            free_gpu_mem = float("nan")
        current_lr = self.optimizer.param_groups[0]["lr"]

        # lr uses scientific notation so typical small rates (e.g. 1e-4) are visible.
        info_str = f'It: {it_global}, Train_acc: {train_acc:.5f},loss: {total_loss / self.config["train"]["print_every"]:.5f}, time elapsed: {time_elapsed:.3f}, free gpu: {free_gpu_mem:.3f}, lr: {current_lr:.3e}'
        print(info_str)
        writers["fp_log_res"].write(info_str + "\n")
        writers["fp_log_res"].flush()
        running_vars["total_loss"] = 0.0

    # Log: evaluation
    def log_loss_eval(self, acc, log_type="eval"):
        """Track the best accuracy for ``log_type`` and log the current result.

        On a new best, records the accuracy and the number of completed epochs,
        and saves a ``train_package_{log_type}_best.pth`` checkpoint. Writes
        the current and best accuracy as scalars, then prints and appends a
        summary line.

        Args:
            acc: accuracy in percent.
            log_type: "eval" or "test"; selects the running_vars keys and writer.
        """
        running_vars = self.full_package["running_vars"]
        best_key = f"best_{log_type}_acc"

        if acc > running_vars[best_key]:
            running_vars[best_key] = acc
            running_vars[f"best_{log_type}_epoch"] = self.full_package["epoch"]
            utils.save_train_package(
                self.full_package,
                os.path.join(
                    self.config["general"]["save_model_dir"],
                    f"train_package_{log_type}_best.pth",
                ),
                self.keys_to_save(),
            )

        # Evaluation runs after update_global_progress, so it_global equals the
        # number of training iterations completed so far. get_global_iteration()
        # would add the already-incremented epoch and land ~one epoch too late.
        it_global = self.full_package["it_global"]
        writer = self.full_package["summary_writers"][log_type]
        writer.add_scalar(f"acc/{log_type}", acc, global_step=it_global)
        writer.add_scalar(
            f"acc/best_{log_type}",
            running_vars[best_key],
            global_step=it_global,
        )

        info_str = f'{log_type}_acc: {acc:.5f}, best_{log_type}_acc: {self.full_package["running_vars"][f"best_{log_type}_acc"]:.5f}'
        print(info_str)
        self.full_package["summary_writers"]["fp_log_res"].write(info_str + "\n")
        self.full_package["summary_writers"]["fp_log_res"].flush()

    def save_and_evaluate(self):
        """Save the latest checkpoint, then evaluate and log on validation and test sets."""
        utils.save_train_package(
            self.full_package,
            os.path.join(
                self.config["general"]["save_model_dir"], "train_package_last.pth"
            ),
            self.keys_to_save(),
        )

        eval_acc = self.eval_loop(self.valloader, continue_training=True)
        self.log_loss_eval(eval_acc, log_type="eval")

        test_acc = self.eval_loop(self.testloader, continue_training=True)
        self.log_loss_eval(test_acc, log_type="test")

    def keys_to_save(self):
        """Return the ``full_package`` keys passed to ``utils.save_train_package``.

        NOTE(review): "imbalance_info" is not validated in
        ``initialize_attributes``; presumably the caller always provides it.
        """
        return [
            "config",
            "model",
            "optimizer",
            "epoch",
            "running_vars",
            "imbalance_info",
        ]

    def get_global_iteration(self):
        """Return ``epoch * iterations_per_epoch + it`` for the current training iteration."""
        it = self.full_package["running_vars"]["it"]
        return (
            self.full_package["epoch"] * self.full_package["iterations_per_epoch"] + it
        )

    # Log: save final stats:
    def save_final_stats(self):
        """Write every ``running_vars`` entry as a ``key: value`` line to the stats log."""
        stat_log = self.full_package["summary_writers"]["stat_log_res"]
        for key, value in self.full_package["running_vars"].items():
            stat_log.write(f"{key}: {value}\n")
        stat_log.flush()
